#%%
# import torch
# import collections

# collections.Iterable = collections.abc.Iterable
# collections.Mapping = collections.abc.Mapping
# collections.MutableSet = collections.abc.MutableSet
# collections.MutableMapping = collections.abc.MutableMapping

# import tltorch as tly


# c = tly.factorized_layers.factorized_convolution.FactorizedConv(3,2,(3,3),factorization='TT')

# print(c.rank)
import torch
import collections

# Re-expose the ABC aliases (collections.Iterable, ...) that Python 3.10 removed from the
# top-level ``collections`` namespace. This must run before the tltorch/tensorly imports
# below; presumably one of them (or a dependency) still uses the old names -- TODO confirm.
import collections.abc  # explicit: don't rely on another module having loaded the submodule

for _abc_name in ("Iterable", "Mapping", "MutableSet", "MutableMapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name

import tltorch
import math

import tensorly as tly

# Switch tensorly to its PyTorch backend. This runs at import time, so merely importing
# this module changes tensorly's active backend for the importing program as well.
tly.set_backend('pytorch')


class Conv2d_mat_vanilla(torch.nn.Conv2d):
    """2D convolution whose kernel is parameterized as a low-rank matrix product.

    The kernel is not stored directly: on every forward pass it is rebuilt as
    ``U @ V.T`` (shape ``self.dims``) and reshaped to the
    ``(out_channels, in_channels // groups, kernel_h, kernel_w)`` layout expected by
    ``torch.nn.functional.conv2d``.

    NOTE(review): the dense ``weight`` parameter created by ``torch.nn.Conv2d`` is still
    registered on the module but is never read by ``forward``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, groups=1, bias=True,
                 dilation=1, start_rank_percent=0.4) -> None:
        """
        Build the low-rank factors ``U`` and ``V`` on top of a standard ``torch.nn.Conv2d``.

        INPUTS:
        in_planes : number of input channels (Pytorch's ``in_channels``)
        out_planes : number of output channels (Pytorch's ``out_channels``)
        kernel_size : int or pair, size of the convolutional filter (Pytorch's standard)
        stride : stride of the filter (Pytorch's standard)
        padding : padding of the convolution (Pytorch's standard)
        groups : number of blocked connections from input to output channels (Pytorch's standard)
        bias : flag variable for the bias to be included (Pytorch's standard)
        dilation : dilation of the convolution (Pytorch's standard)
        start_rank_percent : fraction of ``min(self.dims)`` used as the factorization rank;
            the resulting rank is floored at 3.
        """
        super().__init__(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,
                         bias=bias, dilation=dilation)

        # Keep the raw user arguments; they are passed straight to F.conv2d in forward.
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        # torch.nn.Conv2d has already expanded kernel_size to a 2-tuple; store it as a list
        # so it can be concatenated with the channel dims when reshaping the kernel.
        self.kernel_size = list(self.kernel_size)

        # With groups > 1 each filter only sees in_channels // groups input channels, so the
        # kernel matrix (and V) must use that count, not the full in_channels.
        in_per_group = self.in_channels // self.groups
        self.dims = [self.out_channels, in_per_group * self.kernel_size[0] * self.kernel_size[1]]

        # Floor the rank at minimal_rank (original rationale: "at least 3 channels, for rgb
        # images"). Note this floor can exceed min(self.dims) for very small layers.
        minimal_rank = 3
        self.rank = max(int(min(self.dims) * start_rank_percent), minimal_rank)
        self.U = torch.nn.Parameter(torch.empty(size=(self.dims[0], self.rank)))
        self.V = torch.nn.Parameter(torch.empty(size=(self.dims[1], self.rank)))
        # Convenience list of the factors (U and V are registered as parameters above).
        self.Us = [self.U, self.V]

        self.reset_mat_parameters()  # parameter initialization

    @torch.no_grad()
    def reset_mat_parameters(self):
        """Re-initialize the factors ``U`` and ``V`` in place with ``kaiming_uniform_(a=sqrt(5))``."""
        torch.nn.init.kaiming_uniform_(self.U, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.V, a=math.sqrt(5))
        # NOTE: orthonormalizing the factors (torch.linalg.qr, 'reduced') was sketched here
        # previously but is intentionally not applied.

    def forward(self, input):
        """
        Convolve ``input`` with the kernel reconstructed from the low-rank factors.

        The dense kernel ``U @ V.T`` is reshaped to
        ``(out_channels, in_channels // groups, kernel_h, kernel_w)`` and applied with
        ``torch.nn.functional.conv2d`` using the layer's stride/padding/dilation/groups and
        its bias (``None`` when the layer was built with ``bias=False``).
        """
        weight = self.U @ self.V.T
        weight = torch.reshape(weight, [self.out_channels, self.in_channels // self.groups] + self.kernel_size)
        return torch.nn.functional.conv2d(input, weight, bias=self.bias, stride=self.stride,
                                          padding=self.padding, dilation=self.dilation, groups=self.groups)
